
[Feature] Support cp.reduce.async.bulk.tensor #1667

Merged
LeiWang1999 merged 10 commits into tile-ai:main from Rachmanino:tma-reduce
Jan 14, 2026

Conversation


Rachmanino (Collaborator) commented Jan 13, 2026

fix #1655

  • check swizzle support
  • refactor other examples containing tma reduce

Summary by CodeRabbit

  • New Features

    • Added TMA-based atomic-add support and descriptor-based TMA store helpers for improved GPU atomic operations and performance.
  • Tests

    • New tests validating TMA atomic-add correctness and kernel generation; expanded numeric coverage (float32/float16/bfloat16).
  • Bug Fixes

    • Adjusted flash-attention backward postprocessing wiring affecting certain gradient paths.
  • Chores

    • Minor formatting and whitespace cleanups in example files.



coderabbitai bot commented Jan 13, 2026

📝 Walkthrough

Adds TMA-backed AtomicAdd: new layout inference and swizzle logic, CUtensorMap descriptor construction, TMA lowering that emits descriptor-based tma_store_add calls, SM90 device overloads for descriptor-based TMA ops, FlashAttention backward postprocess wiring changes, formatting tweaks in the examples, and tests validating TMA atomic add.

Changes

  • Examples (formatting): examples/autodd/tilelang_buggy.py, examples/autodd/tilelang_minimized_expected.py. Whitespace and minor string-quote formatting changes; no logic changes.
  • FlashAttention backward: examples/flash_attention/example_gqa_bwd_tma_reduce.py, examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py. Removed layout annotations for dQ/dK/dV and changed where/if postprocessing (mod_post) is created and applied across atomic vs non-atomic branches; the varlen variant adjusts dk/dv reduction handling.
  • AtomicAdd core: src/op/atomic_add.h, src/op/atomic_add.cc. Added a ComputeLinearLayout declaration and implementation, a TMA-aware InferLayout with swizzle selection, and a new lowering path that builds CUtensorMap-like descriptors, computes smem offsets, and emits descriptor-based tma_store_add calls; helper utilities added.
  • CUDA SM90 templates: src/tl_templates/cuda/copy_sm90.h. Added descriptor-based tma_store_add overloads (1D–5D) that emit inline ASM cp.reduce.async.bulk.tensor.*.global.shared::cta.add.bulk_group operations.
  • Utils and copy changes: src/op/utils.{h,cc}, src/op/copy.cc. Moved/added the to_CUtensorMapDataType and ReverseArray helpers into utils.{h,cc}; removed duplicate helpers from copy.cc.
  • Tests: testing/python/language/test_tilelang_language_atomic_add.py. Added tma_atomic_add_program and test_tma_atomic_add verifying TMA atomic_add (and explicit swizzle); extended atomic tests to cover float16 and bfloat16.

Sequence Diagram(s)

sequenceDiagram
    participant Frontend as Frontend/Op
    participant Infer as InferLayout
    participant Swizzle as Swizzle Resolver
    participant Lower as Lowering
    participant Device as Device ASM (tma_store_add)

    Frontend->>Infer: Request layout for AtomicAdd (use_tma?)
    alt use_tma == true
        Infer->>Swizzle: Probe/build swizzle & layout map
        Swizzle-->>Infer: Selected swizzle or fallback layout
        Infer->>Lower: Provide layout map + swizzle info
        Lower->>Lower: Build CUtensorMap descriptor, compute smem offsets
        Lower->>Device: Emit descriptor-based tma_store_add calls (inline ASM)
    else
        Infer->>Lower: Legacy non-TMA lowering info
        Lower->>Device: Emit legacy atomic_add sequence
    end
    Device-->>Frontend: Kernel source / compiled kernel

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • chengyupku

Poem

🐰 I hop through swizzles, map each byte,
I build descriptors in the night.
I store and add with tiny cheer,
Kernels wake and loudly cheer,
A rabbit dances—GPU delight.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 20.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: Title '[Feature] Support cp.reduce.async.bulk.tensor' is specific and directly describes the primary feature addition in the changeset.
  • Linked Issues Check ✅ Passed: The PR addresses issue #1655 by adding TMA-based atomic add support with proper swizzle handling across multiple files and tests.
  • Out of Scope Changes Check ✅ Passed: All changes are directly related to implementing TMA atomic add support and fixing the swizzle bug. Minor formatting changes in examples are consistent with refactoring efforts mentioned in commit messages.



🧹 Recent nitpick comments
src/op/atomic_add.cc (1)

317-332: Consider extracting shared ComputeLinearLayout utility.

This implementation is nearly identical to CopyNode::ComputeLinearLayout in src/op/copy.cc (lines 269-284). Both create a tiled layout splitting dimensions into 256-element blocks with the same pattern.

Consider extracting this as a free function in utils.h/utils.cc to avoid duplication:

// In utils.h
Layout ComputeTiledLinearLayout(const Array<PrimExpr>& shape, int tile_size = 256);

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 27f42e1 and 8f320da.

📒 Files selected for processing (4)
  • src/op/atomic_add.cc
  • src/op/copy.cc
  • src/op/utils.cc
  • src/op/utils.h
💤 Files with no reviewable changes (1)
  • src/op/copy.cc
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/op/atomic_add.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/op/atomic_add.cc
📚 Learning: 2026-01-12T07:25:31.685Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.

Applied to files:

  • src/op/atomic_add.cc
🧬 Code graph analysis (2)
src/op/utils.h (1)
src/op/utils.cc (2)
  • to_CUtensorMapDataType (96-157)
  • to_CUtensorMapDataType (96-96)
src/op/atomic_add.cc (4)
src/op/copy.cc (6)
  • ComputeLinearLayout (270-285)
  • ComputeLinearLayout (270-270)
  • InferLayout (289-462)
  • InferLayout (289-290)
  • InferLayout (1731-1734)
  • InferLayout (1731-1732)
src/op/utils.h (1)
  • IsFragmentBuffer (42-44)
src/op/utils.cc (2)
  • to_CUtensorMapDataType (96-157)
  • to_CUtensorMapDataType (96-96)
tilelang/language/builtin.py (1)
  • create_tma_descriptor (104-113)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (5)
src/op/utils.h (1)

9-26: LGTM!

The new include and public API declarations are well-structured. The ReverseArray template provides a clean utility for layout conversions, and the to_CUtensorMapDataType declaration consolidates CUDA type mapping into the common utilities.
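
For context, here is a minimal, self-contained sketch of what a ReverseArray-style helper can look like; the exact signature and container types in utils.h may differ (TVM arrays are simplified to std::vector here). Reversing shape/stride arrays before filling a TMA descriptor is the usual step because cuTensorMapEncodeTiled expects dimension 0 to be the innermost (fastest-varying) axis, while TVM buffer shapes are listed outermost-first.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the ReverseArray helper mentioned above:
// returns the input sequence in reverse order, e.g. a row-major shape
// {B, M, N} becomes {N, M, B} for CUtensorMap-style descriptor setup.
template <typename T>
std::vector<T> ReverseArray(std::vector<T> arr) {
  std::reverse(arr.begin(), arr.end());
  return arr;
}

int main() {
  std::vector<uint64_t> shape = {8, 128, 64};             // outermost-first (TVM order)
  std::vector<uint64_t> tma_shape = ReverseArray(shape);  // innermost-first (TMA order)
  return tma_shape.front() == 64 ? 0 : 1;
}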

src/op/utils.cc (1)

95-157: LGTM!

The centralized type mapping function correctly handles all supported TVM data types. The mapping of int16 to CU_TENSOR_MAP_DATA_TYPE_UINT16 (line 128) aligns with CUDA's TMA API, which doesn't distinguish signed/unsigned for 16-bit types at the descriptor level.
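
To make the mapping concrete, a simplified sketch of a dtype-to-CUtensorMapDataType mapping is shown below; the actual to_CUtensorMapDataType in src/op/utils.cc takes a TVM DataType and covers more cases, so the signature and coverage here are illustrative only.

#include <cuda.h>  // CUtensorMapDataType (CUDA 12+ driver API)
#include <stdexcept>

// Illustrative mapping from a (bits, is_float, is_bfloat) description to
// CUtensorMapDataType. The enum has no signed 8/16-bit members, so 8- and
// 16-bit integers map to the UINT variants, matching the behavior above.
inline CUtensorMapDataType ToCUtensorMapDataTypeSketch(int bits, bool is_float,
                                                       bool is_bfloat) {
  if (is_bfloat && bits == 16) return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
  if (is_float) {
    if (bits == 16) return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    if (bits == 32) return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    if (bits == 64) return CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
  } else {
    if (bits == 8)  return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    if (bits == 16) return CU_TENSOR_MAP_DATA_TYPE_UINT16;
    if (bits == 32) return CU_TENSOR_MAP_DATA_TYPE_INT32;
    if (bits == 64) return CU_TENSOR_MAP_DATA_TYPE_INT64;
  }
  throw std::runtime_error("unsupported data type for TMA descriptor");
}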

src/op/atomic_add.cc (3)

617-655: LGTM!

The TMA reduce lowering correctly handles both single-call and loop-split cases. The loop splitting at lines 617-639 properly advances both shared memory offset and global coordinates when inner_box_dim exceeds instruction_dim.
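
As a rough illustration of that split (inner_box_dim and instruction_dim follow the naming in this comment; the real code emits TIR that lowers to one tma_store_add per chunk rather than running on the host), the key point is that each chunk advances the shared-memory byte offset and the innermost global coordinate in lockstep:

#include <cstdint>
#include <cstdio>

// Hypothetical host-side model of splitting a wide inner box into several
// descriptor-based TMA reduce calls. Assumes inner_box_dim is a multiple of
// instruction_dim for simplicity.
void EmitTmaStoreAddChunks(int inner_box_dim, int instruction_dim, int elem_bytes,
                           int64_t base_smem_offset, int64_t base_inner_coord) {
  int num_chunks = inner_box_dim / instruction_dim;
  for (int c = 0; c < num_chunks; ++c) {
    int64_t smem_offset = base_smem_offset + int64_t(c) * instruction_dim * elem_bytes;
    int64_t inner_coord = base_inner_coord + int64_t(c) * instruction_dim;
    std::printf("tma_store_add(desc, smem_base + %lld, coord0 = %lld)\n",
                (long long)smem_offset, (long long)inner_coord);
  }
}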


444-466: LGTM on TMA descriptor setup.

The TMADesc construction properly validates rank bounds (1-5 dimensions), enforces dtype consistency between buffers, and correctly maps the data type using the centralized utility.


396-407: Non-TMA path layout validation is well-structured.

The check ensuring both fragment buffers have matching layouts when both are present in the layout map provides good runtime validation with a clear error message.




@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Rachmanino marked this pull request as ready for review on January 14, 2026 02:53
coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)

340-343: Bug: dk should be dk_shared for consistency.

Line 343 copies directly from the dk fragment to global memory, but the preceding lines establish a pattern of copying through shared memory first (dv → dv_shared → dV, and likewise dk → dk_shared, which should then be copied to dK). The varlen counterpart (flashattn_bwd_split in example_gqa_bwd_tma_reduce_varlen.py, lines 480-481) correctly uses dk_shared.

Proposed fix
             T.copy(dv, dv_shared)
             T.copy(dv_shared, dV[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
             T.copy(dk, dk_shared)
-            T.copy(dk, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
+            T.copy(dk_shared, dK[bx % groups, bz, by * block_M : (by + 1) * block_M, bx // groups, :])
🤖 Fix all issues with AI agents
In `@src/op/atomic_add.cc`:
- Around line 443-449: The code dereferences as_const_int results for
shared_tensor->shape elements (used to set mat_stride and mat_continuous)
without checking for nullptr, which can crash for symbolic dimensions; update
the block around mat_stride/mat_continuous and the call to
makeGemmABLayoutHopper to first call as_const_int(...) for each dimension,
validate the returned pointer is non-null, and handle the non-constant case
(e.g., choose a safe default, skip swizzle_layout creation, or fall back to a
non-swizzled Layout) so you never dereference a null pointer when constructing
swizzle_layout via makeGemmABLayoutHopper.

In `@testing/python/language/test_tilelang_language_atomic_add.py`:
- Around line 388-399: In test_tma_atomic_add, replace the torch.allclose check
with torch.testing.assert_close to match the file's convention and give better
failure messages when validating `out` against the expected tensor; also fix the
comment typo in the last assertion from "appiled" to "applied" referencing
`kernel` and `kernel_with_explicit_swizzle` to ensure clarity.
- Around line 354-367: The test references tilelang.layout.make_swizzled_layout
in tma_atomic_add_program (the explicit_swizzle branch) but tilelang.layout is
not imported; add an import for the layout submodule (e.g., import
tilelang.layout or from tilelang import layout) at the top of the test file so
tilelang.layout.make_swizzled_layout resolves when explicit_swizzle=True; ensure
the import is placed alongside other tilelang imports and update any existing
import style to match the file.
🧹 Nitpick comments (5)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)

576-582: Consider a dQ-only postprocess function to avoid wasted computation.

The current pattern creates dummy zero tensors for dk and dv (line 581) only to satisfy mod_post's signature, while the second kernel inside flashattn_bwd_postprocess (processing dK/dV) runs but its outputs are discarded. The actual dk and dv come from .sum(0) on line 582.

For consistency with example_mha_bwd_bshd.py and example_mha_bwd_bhsd.py (which have dQ-only postprocess functions), consider refactoring flashattn_bwd_postprocess to handle only dQ, eliminating the unnecessary allocation and kernel execution.

examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)

399-408: Consider a dQ-only postprocess function to avoid wasted computation.

Same pattern as the varlen version: dummy zero tensors are created for dk and dv to call mod_post, but only the dQ result is used. The kernel processing dK/dV inside flashattn_bwd_postprocess executes but its output is discarded since actual dk/dv come from .sum(0).

src/op/atomic_add.h (1)

80-81: Declaration is appropriate; consider deduplication with CopyNode::ComputeLinearLayout.

The method declaration follows the existing pattern. However, the implementation in src/op/atomic_add.cc (lines 386-401) is identical to CopyNode::ComputeLinearLayout in src/op/copy.cc (lines 340-355). Consider extracting this to a shared utility function to avoid code duplication.

testing/python/language/test_tilelang_language_atomic_add.py (1)

3-3: Top-level torch import is good, but redundant imports exist in helper functions.

The top-level import is appropriate for the new test. Note that torch is also imported locally within run_atomic_add, run_tile_atomic_add, etc. Consider removing those redundant local imports for consistency.

src/op/atomic_add.cc (1)

393-400: Consider extracting magic number 256 as a named constant.

The value 256 is used for tiling but its significance isn't documented. A named constant would improve readability and make it easier to adjust if needed.

♻️ Suggested refactor
+// Maximum elements per TMA tile dimension for linear layout
+static constexpr int kTMALinearTileSize = 256;
+
 Layout AtomicAddNode::ComputeLinearLayout(const Buffer &shared_tensor) const {
   // ...
   for (size_t i = 0; i < input_size.size(); i++) {
-    forward_index.push_back(FloorDiv(forward_vars[i], 256));
+    forward_index.push_back(FloorDiv(forward_vars[i], kTMALinearTileSize));
   }
   for (size_t i = 0; i < input_size.size(); i++) {
-    forward_index.push_back(FloorMod(forward_vars[i], 256));
+    forward_index.push_back(FloorMod(forward_vars[i], kTMALinearTileSize));
   }
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 802951e and 5673338.

📒 Files selected for processing (8)
  • examples/autodd/tilelang_buggy.py
  • examples/autodd/tilelang_minimized_expected.py
  • examples/flash_attention/example_gqa_bwd_tma_reduce.py
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
  • src/op/atomic_add.cc
  • src/op/atomic_add.h
  • src/tl_templates/cuda/copy_sm90.h
  • testing/python/language/test_tilelang_language_atomic_add.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • testing/python/language/test_tilelang_language_atomic_add.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/language/test_tilelang_language_atomic_add.py
📚 Learning: 2026-01-12T07:25:31.685Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1631
File: src/transform/thread_storage_sync.cc:1126-1137
Timestamp: 2026-01-12T07:25:31.685Z
Learning: In TileLang's thread storage synchronization pass (src/transform/thread_storage_sync.cc), at the IR level where PointerAccessIsDisjoint is called, the threads array in AccessEntry is guaranteed to contain all three thread dimensions (threadIdx.x, threadIdx.y, threadIdx.z), making access to the last 3 elements via `threads[threads.size() + idx - 3]` safe.

Applied to files:

  • src/op/atomic_add.cc
📚 Learning: 2025-12-15T07:23:50.065Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: tilelang/contrib/cutedsl/cpasync.py:45-55
Timestamp: 2025-12-15T07:23:50.065Z
Learning: In tilelang/contrib/cutedsl/cpasync.py, using AddressSpace.generic for TMA descriptor pointers (tensormap_ptr) in the extract_tensormap_ptr function is correct. When creating ptr_type with _cute_ir.PtrType.get for TMA descriptors in CuTeDSL, AddressSpace.generic should be used, not a device-specific or constant address space.

Applied to files:

  • src/tl_templates/cuda/copy_sm90.h
🧬 Code graph analysis (4)
src/op/atomic_add.h (2)
src/op/atomic_add.cc (2)
  • ComputeLinearLayout (387-402)
  • ComputeLinearLayout (387-387)
src/op/copy.cc (2)
  • ComputeLinearLayout (341-356)
  • ComputeLinearLayout (341-341)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (4)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)
  • flashattn_bwd_postprocess (133-163)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
  • flashattn_bwd_postprocess (179-196)
examples/flash_attention/example_mha_bwd_bshd.py (1)
  • flashattn_bwd_postprocess (129-147)
examples/flash_attention/example_mha_bwd_bhsd.py (1)
  • flashattn_bwd_postprocess (132-150)
testing/python/language/test_tilelang_language_atomic_add.py (4)
tilelang/language/allocate.py (1)
  • alloc_shared (39-54)
tilelang/language/annotations.py (1)
  • annotate_layout (27-40)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (62-71)
tilelang/language/atomic.py (1)
  • atomic_add (120-227)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (3)
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (1)
  • flashattn_bwd_postprocess (198-228)
examples/flash_attention/example_mha_bwd_bshd.py (1)
  • flashattn_bwd_postprocess (129-147)
examples/flash_attention/example_mha_bwd_bhsd.py (1)
  • flashattn_bwd_postprocess (132-150)
🔇 Additional comments (6)
examples/autodd/tilelang_buggy.py (1)

76-79: Formatting changes look good.

Minor reformatting of method bodies to single-line returns and whitespace adjustments. No functional changes.

Also applies to: 106-107

examples/autodd/tilelang_minimized_expected.py (1)

46-46: LGTM!

Minor formatting change (quote style). No functional impact.

testing/python/language/test_tilelang_language_atomic_add.py (1)

369-373: Good addition of float16/bfloat16 coverage.

Extending test_atomic_different_memory_orders to cover additional data types improves test coverage for the atomic operations.

src/tl_templates/cuda/copy_sm90.h (1)

265-331: New TMA descriptor-based tma_store_add overloads are correct.

The 1D-5D variants follow the same pattern as existing tma_store overloads and correctly:

  • Use "l" constraint for 64-bit descriptor and "r" for 32-bit values
  • Include "memory" clobber for proper ordering
  • Omit cache hints (consistent with the existing tma_store_add and optional per PTX spec for cp.reduce.async.bulk.tensor)
  • Use the correct syntax: cp.reduce.async.bulk.tensor.{1-5}d.global.shared::cta.add.bulk_group
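
For reference, a minimal sketch of what a descriptor-based 1D overload of this kind can look like, modeled on the PTX syntax quoted above; the actual overloads in copy_sm90.h may differ in naming, arch guards, and how the shared-memory address is obtained.

#include <cuda.h>  // CUtensorMap
#include <cstdint>

// Sketch: 1D TMA reduce-add from shared memory (shared::cta) into global
// memory. `desc` is a host-built CUtensorMap, `smem_ptr` points into shared
// memory, and `crd0` is the global coordinate along the single dimension.
__device__ __forceinline__ void tma_store_add_1d_sketch(const CUtensorMap *desc,
                                                        void *smem_ptr,
                                                        int32_t crd0) {
  uint64_t gmem_desc = reinterpret_cast<uint64_t>(desc);  // 64-bit "l" operand
  uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));  // 32-bit "r" operand
  asm volatile("cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group "
               "[%0, {%2}], [%1];"
               :
               : "l"(gmem_desc), "r"(smem_addr), "r"(crd0)
               : "memory");
}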
src/op/atomic_add.cc (2)

516-728: TMA lowering path implementation looks reasonable overall.

The TMA-based lowering correctly:

  • Builds the TMA descriptor with shape/stride information
  • Handles swizzle layout detection
  • Computes shared memory offsets
  • Splits operations when inner_box_dim exceeds instruction_dim
  • Guards TMA operations with thread predicate (line 727)

The architecture follows the same patterns as existing TMA copy operations, which aids maintainability.


52-68: This mapping is not an inconsistency—CUDA's CUtensorMapDataType enum does not define INT8 or INT16 variants, only UINT8 and UINT16. The code correctly uses the only available options for 8-bit and 16-bit integers. Additionally, TMA descriptors are for bulk load/store operations; atomic operations use separate CUDA atomic APIs that are independent of the data type descriptor.

Likely an incorrect or invalid review comment.


Comment on lines +443 to +449
int dim = shared_tensor->shape.size();
const int64_t mat_stride = *as_const_int(shared_tensor->shape[dim - 2]);
const int64_t mat_continuous =
    *as_const_int(shared_tensor->shape[dim - 1]);
Layout swizzle_layout =
    makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous,
                           shared_tensor->dtype.bits(), /*k_inner=*/true);

⚠️ Potential issue | 🟠 Major

Potential null pointer dereference when shape dimensions are not constant.

as_const_int returns nullptr if the expression is not a compile-time constant. Dereferencing without a null check will crash if shared_tensor->shape contains symbolic dimensions.

🐛 Proposed fix
       int dim = shared_tensor->shape.size();
-      const int64_t mat_stride = *as_const_int(shared_tensor->shape[dim - 2]);
-      const int64_t mat_continuous =
-          *as_const_int(shared_tensor->shape[dim - 1]);
+      auto mat_stride_ptr = as_const_int(shared_tensor->shape[dim - 2]);
+      auto mat_continuous_ptr = as_const_int(shared_tensor->shape[dim - 1]);
+      if (!mat_stride_ptr || !mat_continuous_ptr) {
+        // Cannot determine swizzle layout for non-constant shape; use linear layout
+        result_map.Set(shared_tensor, ComputeLinearLayout(shared_tensor));
+        return result_map;
+      }
+      const int64_t mat_stride = *mat_stride_ptr;
+      const int64_t mat_continuous = *mat_continuous_ptr;
       Layout swizzle_layout =
           makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous,
                                  shared_tensor->dtype.bits(), /*k_inner=*/true);
🤖 Prompt for AI Agents
In `@src/op/atomic_add.cc` around lines 443 - 449, The code dereferences
as_const_int results for shared_tensor->shape elements (used to set mat_stride
and mat_continuous) without checking for nullptr, which can crash for symbolic
dimensions; update the block around mat_stride/mat_continuous and the call to
makeGemmABLayoutHopper to first call as_const_int(...) for each dimension,
validate the returned pointer is non-null, and handle the non-constant case
(e.g., choose a safe default, skip swizzle_layout creation, or fall back to a
non-swizzled Layout) so you never dereference a null pointer when constructing
swizzle_layout via makeGemmABLayoutHopper.

Comment on lines +649 to +655
for (const auto &check : swizzle_checks) {
  if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) {
    LOG(WARNING) << "AtomicAdd TMA cannot support swizzled layout with "
                    "inner_box_dim_ > "
                 << check.max_dim;
  }
}

⚠️ Potential issue | 🟠 Major

Swizzle constraint violation only warns; it doesn't prevent execution.

When inner_box_dim_ exceeds the swizzle's maximum dimension, only a warning is logged but execution continues. This could lead to silent data corruption or undefined behavior with TMA hardware.

🐛 Proposed fix: Either fail or fall back to no swizzle
     for (const auto &check : swizzle_checks) {
       if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) {
-        LOG(WARNING) << "AtomicAdd TMA cannot support swizzled layout with "
-                        "inner_box_dim_ > "
-                     << check.max_dim;
+        LOG(WARNING) << "AtomicAdd TMA: inner_box_dim_ (" << inner_box_dim_
+                     << ") exceeds max (" << check.max_dim
+                     << ") for swizzle type, falling back to SWIZZLE_NONE";
+        desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+        break;
       }
     }

LeiWang1999 (Member) left a comment

LGTM, just one minor issue

Refactor CUDA data type mapping by moving the to_CUtensorMapDataType function to utils.cc and utils.h, while removing redundant definitions from atomic_add.cc and copy.cc.
LeiWang1999 merged commit 732971a into tile-ai:main on Jan 14, 2026
6 checks passed
tzj-fxz pushed a commit to wfloveiu/tilelang that referenced this pull request Jan 14, 2026
* support cp.reduce.async.bulk.tensor and add test

* Refactor flash attention example by removing unnecessary layout annotations

* support swizzle layout for tma reduce

* auto swizzle for non-1d tma atomic add

* upd example and test

* lint

* typo

* add constraint for test

* Refactor CUDA data type mapping by moving the to_CUtensorMapDataType function to utils.cc and utils.h, while removing redundant definitions from atomic_add.cc and copy.cc.

* lint
Rachmanino deleted the tma-reduce branch on January 14, 2026 07:06
LeiWang1999 added a commit that referenced this pull request Jan 26, 2026
* finish KDA algorithm in tilelang

* fix pre-commit.ci

* fix pre-commit.ci

* fix pre-commit local

* [Style] Fix some code styles

* [Refactor] Remove redundant swizzle for they can be automatically done

* [Refactor] remove chunk_bwd_intra.py and rename chunk_bwd_intra_op.py and do some fix form coderabbitai

* update ruff

* update pre-commit

* [Enhancement] Improve unroll loop functionality for dynamic extent and corresponding test case (#1654)

* Add unroll loop functionality and corresponding test case

- Introduced a new `UnrollLoop` function in the transform module to unroll loops based on various configuration options.
- Added a test case in `test_tilelang_language_unroll.py` to validate the behavior of `T.unroll` with only the extent parameter, ensuring correct kernel generation with unroll pragmas.

* Refactor unroll kernel implementation and update test case

- Changed the kernel function in `test_tilelang_language_unroll.py` to use a new `unroll_kernel` function that compiles and returns the output tensor, improving clarity and structure.
- Updated the `OptimizeForTarget` function in `phase.py` to ensure the `UnrollLoop` transformation is applied correctly, maintaining consistency in optimization phases.

* lint fix

* lint fix

* [Bugfix] Fix missing annotations for default CallNode Visitor (#1659)

tvm fix

* [Clean] Remove unnecessary debug print (#1661)

remove unnecessary debug print

* [Bugfix] Fix variable scoping issue in InjectSoftwarePipeline for transitive LetStmt dependencies (#1657)

* [Enhancement] Update global load/store functions for CUDA compatibility (#1652)

Refactor the `ld_global_256` and `st_global_256` functions to support both CUDA versions above 12.9 and earlier versions. This change ensures that 256-bit loads and stores are handled correctly across different CUDA versions, improving performance and compatibility. The implementation now uses two 128-bit loads/stores for older versions, enhancing the robustness of the codebase.

* Update comments in global load/store functions for CUDA compatibility

Clarified comments in `ld_global_256` and `st_global_256` functions to indicate that the fallback for CUDA versions below 12.9 may have performance regressions. This change enhances code readability and provides better context for developers working with different CUDA versions.

* Update submodule and enhance LetStmt handling in inject_pipeline.cc

- Updated the TVM submodule to the latest commit.
- Improved the handling of LetStmt in the inject_pipeline.cc file to account for transitive dependencies on loop variables, ensuring correct variable substitution in rewritten blocks.
- Adjusted test_tilelang_issue_1263.py to remove unnecessary jit decorator and updated the kernel compilation process with specific pass configurations.

* lint fix

* revert tvm

* remove unused test

* test fix

* [Refactor] Improve CallNode handling to include annotations in various operations (#1663)

* [Enhancement] Update CallNode handling to include annotations in various operations

- Modified CallNode invocations in multiple files to ensure that annotations are passed correctly, enhancing the consistency and functionality of the codebase.
- Removed the "use_tma" annotation from AtomicAddNode and adjusted related calls to maintain expected behavior.
- Updated CUDA intrinsic dispatch functions to include annotations, improving compatibility and correctness in CUDA operations.

* lint fix

* [EagerJIT] Add Support for Parameter Only Kernel Compilation (#1664)

* [Fix] Refactor type hint extraction logic in DSLMutator for better clarity and handling of annotations

* [Refactor] Remove redundant tensor creation in loop layout tests and update kernel compilation parameters

* [AutoDD] Add Tilelang AutoDD to Reduce Buggy Program (#1639)

* [Feat] Add tilelang autodd for delta debugging

* fix typos

* fix lint error

* fix typos

* fix lint error

* fix bugs

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix codeview comments

* [Refactor] Move AutoDD detection to env module and update import logic

* Refactor: Relocate the _is_running_autodd function to the env module for better organization and encapsulation.
* Update initialization logic to skip logger and heavy imports based on a new light import mode, enhancing flexibility in module usage.
* Ensure consistent handling of environment variables across the package, improving overall code clarity and maintainability.

* [Documentation] Add AutoDD section to debug_tools_for_tilelang.md

* Introduced a comprehensive guide on AutoDD (Automatic Delta Debugging) for isolating bugs in TileLang programs.
* Explained Delta Debugging methodology, usage, parameters, and provided examples for clarity.
* Highlighted the benefits of using AutoDD for large codebases and hard-to-locate errors, emphasizing time-saving aspects.
* Included tips for effective usage and a reference to a complete example in the documentation.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* rebase origin

* [Feature] Support `cp.reduce.async.bulk.tensor` (#1667)

* support cp.reduce.async.bulk.tensor and add test

* Refactor flash attention example by removing unnecessary layout annotations

* support swizzle layout for tma reduce

* auto swizzle for non-1d tma atomic add

* upd example and test

* lint

* typo

* add constraint for test

* Refactor CUDA data type mapping by moving the to_CUtensorMapDataType function to utils.cc and utils.h, while removing redundant definitions from atomic_add.cc and copy.cc.

* lint

* rename basename according to CI

* Update submodule TVM and remove deprecated KDA example files

- Updated the TVM submodule to commit 354eef9a.
- Removed several outdated KDA example files and utility scripts that are no longer in use, including chunk_bwd_dqkwg.py, chunk_bwd_dv.py, chunk_bwd_gla_dA.py, chunk_bwd_intra.py, chunk_delta_bwd.py, chunk_delta_h_fwd.py, chunk_inter_solve_fused.py, chunk_intra_token_parallel.py, chunk_o.py, README.md, test_utils_kda.py, wy_fast_bwd.py, wy_fast.py, and various FLA_KDA implementations.

* lint fix

---------

Co-authored-by: wufang <wufang@MBP-MK6VR66Y2M-2329.local>
Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: Kuris <227995639+kurisu6912@users.noreply.github.com>
Co-authored-by: Kexing Zhou <KEKE_046@pku.edu.cn>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
Co-authored-by: Tong WU <109033598+Rachmanino@users.noreply.github.com>


Development

Successfully merging this pull request may close these issues.

[BUG] TMA atomic add's res is wrong when smem is swizzled

2 participants